#include "tensorrt_llm/batch_manager/trtGptModelInflightBatching.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/rawEngine.h"
#include "tensorrt_llm/runtime/tllmLogger.h"
#include "tests/utils/common.h"
#include "tests/utils/engines.h"
#include "tests/utils/executorUtils.h"

#include "gtest/gtest.h"

#include <chrono>
#include <cstddef>
#include <map>
#include <memory>
#include <optional>
#include <random>
#include <sstream>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>

namespace tensorrt_llm::testing
{

/// One parameter combination for the constant-trivial-decoder tests.
struct TrivialConstantDecoderTestParameters
{
    /// Tuple layout produced by ::testing::Combine; element order matches the member order below.
    using TupleT = std::tuple<runtime::SizeType32, runtime::SizeType32, runtime::SizeType32, runtime::SizeType32,
        runtime::SizeType32, runtime::SizeType32, runtime::SizeType32, runtime::SizeType32>;

    runtime::SizeType32 randomSeed;      // seed for the RNG that draws the constant logits
    runtime::SizeType32 vocabSize;       // vocabulary size of the engine / model
    runtime::SizeType32 maxNumTokens;    // token budget per batch
    runtime::SizeType32 maxBeamWidth;    // largest supported beam width
    runtime::SizeType32 maxBatchSize;    // largest number of concurrent sequences
    runtime::SizeType32 numRequests;     // number of requests sent to the executor
    runtime::SizeType32 promptLength;    // tokens per prompt
    runtime::SizeType32 maxOutputLength; // tokens to generate per request

    /// Deliberately implicit: gtest's ConvertGenerator turns each TupleT from ::testing::Combine into this type.
    TrivialConstantDecoderTestParameters(TupleT t) // NOLINT: implicit conversion from the 'Combine' tuple
        : randomSeed(std::get<0>(t))
        , vocabSize(std::get<1>(t))
        , maxNumTokens(std::get<2>(t))
        , maxBeamWidth(std::get<3>(t))
        , maxBatchSize(std::get<4>(t))
        , numRequests(std::get<5>(t))
        , promptLength(std::get<6>(t))
        , maxOutputLength(std::get<7>(t))
    {
    }
};

template <typename TLogits>
struct DecoderTestShared
{
    static constexpr runtime::SizeType32 kNumTokensPerBlock = 64;
    static constexpr runtime::SizeType32 kKvCacheMaxTokens = 2048 * 8;

    DecoderTestShared(std::shared_ptr<runtime::TllmLogger> logger, std::mt19937 rng,
        std::shared_ptr<executor::Executor> executor, std::vector<TLogits> randomLogits)
        : logger(std::move(logger))
        , rng(rng)
        , executor(std::move(executor))
        , randomLogits(std::move(randomLogits)){};
    std::shared_ptr<runtime::TllmLogger> logger;
    std::mt19937 rng;
    std::shared_ptr<executor::Executor> executor;
    std::vector<TLogits> randomLogits;
};

/// Builds the full stack exercised by DecoderTest for one parameter set.
/// The stack is: seeded random logits, a constant trivial decoder engine built from them, a minimal paged-KV GPT
/// model config, and an in-flight-batching executor wrapping the model.
/// @param params Parameter combination (seed, vocab size, token/batch limits, beam width).
/// @return Shared state owning the logger, RNG, executor and generated logits.
template <typename TLogits>
std::unique_ptr<DecoderTestShared<TLogits>> SetupDecoderTest(TrivialConstantDecoderTestParameters const& params)
{
    namespace engines = tensorrt_llm::testing::utils::engines;
    using Shared = DecoderTestShared<TLogits>;

    auto logger = std::make_shared<runtime::TllmLogger>();

    // Seeding with randomSeed makes the logits reproducible for a given parameter set.
    auto generator = std::mt19937(params.randomSeed);
    auto logits = tensorrt_llm::testing::randomLogits<std::mt19937, TLogits>(params.vocabSize, &generator);

    auto const decoderParameters = engines::ConstantTrivialDecoderParameters<TLogits>{
        engines::TrivialDecoderParameters{params.vocabSize, params.maxBatchSize, params.maxNumTokens,
            Shared::kNumTokensPerBlock, params.maxBeamWidth, false},
        logits};
    auto engineHostMemory = engines::createConstantTrivialDecoder<TLogits>(decoderParameters, logger);
    // NOTE(review): release() passes the raw host-memory pointer to RawEngine; confirm RawEngine (or the model)
    // takes ownership, otherwise the buffer is leaked on purpose to outlive the executor.
    auto const engine = runtime::RawEngine(engineHostMemory.release());

    // Minimal model: 1 layer, 1 attention layer, 0 RNN layers, 1 head, hidden size 1, logits dtype of TLogits.
    auto modelConfig = runtime::ModelConfig(params.vocabSize, 1, 1, 0, 1, 1, runtime::TRTDataType<TLogits>::value);
    modelConfig.useGptAttentionPlugin(true);
    modelConfig.setModelVariant(runtime::ModelConfig::ModelVariant::kGpt);
    modelConfig.usePackedInput(true);
    modelConfig.setKVCacheType(runtime::ModelConfig::KVCacheType::kPAGED);
    modelConfig.setMaxNumTokens(params.maxNumTokens);
    modelConfig.setMaxBatchSize(params.maxBatchSize);
    modelConfig.setMaxBeamWidth(params.maxBeamWidth);
    // Sequence and input length limits are both capped at the per-batch token budget.
    modelConfig.setMaxSequenceLen(params.maxNumTokens);
    modelConfig.setMaxInputLen(params.maxNumTokens);
    modelConfig.setLayerTypes({runtime::ModelConfig::LayerType::kATTENTION});
    modelConfig.setTokensPerBlock(Shared::kNumTokensPerBlock);
    modelConfig.setPagedContextFMHA(true);

    auto const worldConfig = runtime::WorldConfig();

    auto cacheConfig = executor::KvCacheConfig{};
    cacheConfig.setMaxTokens(Shared::kKvCacheMaxTokens);

    // NOTE(review): positional ExecutorConfig constructor; argument order must track the executor API exactly.
    auto const executorConfig
        = tensorrt_llm::executor::ExecutorConfig(params.maxBeamWidth, executor::SchedulerConfig(), cacheConfig, true,
            true, 1, 1, executor::BatchingType::kINFLIGHT, params.maxBatchSize, params.maxNumTokens, std::nullopt,
            std::nullopt, std::nullopt, std::nullopt, false, 1, std::nullopt, executor::ExtendedRuntimePerfKnobConfig(),
            std::nullopt, 0, executor::ExecutorConfig::kDefaultMaxSeqIdleMicroseconds, std::nullopt, std::nullopt);

    auto model = std::make_shared<batch_manager::TrtGptModelInflightBatching>(
        logger, modelConfig, worldConfig, engine, false, executorConfig, false);
    auto executorInstance = std::make_shared<executor::Executor>(model, executorConfig);

    return std::make_unique<Shared>(logger, generator, executorInstance, logits);
}

/// Value-parameterized fixture around an executor driving a constant trivial decoder engine whose logits have
/// type TLogits.
template <typename TLogits>
class DecoderTest : public ::testing::Test, public ::testing::WithParamInterface<TrivialConstantDecoderTestParameters>
{
protected:
    /// Logger, RNG, executor and logits built by SetupDecoderTest for the current parameter set.
    std::unique_ptr<DecoderTestShared<TLogits>> state;

    DecoderTest()
    {
        // gtest exposes the current parameter already in the fixture constructor.
        auto const params = GetParam();
        state = SetupDecoderTest<TLogits>(params);
    }

    /// Sends parameters.numRequests identical prompts through the executor.
    /// It checks that every request produced responses, and that each response is error-free and carries exactly
    /// one beam of parameters.maxOutputLength tokens.
    void runDecoderTest(TrivialConstantDecoderTestParameters const& parameters)
    {
        auto const requestTokens = createConsecutiveTokenSequence(parameters.promptLength, parameters.vocabSize, 0);
        auto requests = std::vector<executor::Request>{};
        requests.reserve(static_cast<std::size_t>(parameters.numRequests));
        for (auto i = 0; i < parameters.numRequests; i++)
        {
            // NOTE(review): the 4th OutputConfig flag is presumably excludeInputFromOutput, which is why the
            // length check below compares against maxOutputLength only — confirm against executor.h.
            requests.emplace_back(requestTokens, parameters.maxOutputLength, false, executor::SamplingConfig{},
                executor::OutputConfig{false, false, false, true, false, false});
        }
        // 3'600'000 ms = one hour upper bound for draining all requests.
        auto const accumulatedResponses
            = runThroughRequests(*state->executor, requests, std::chrono::duration<float, std::milli>(3600000));
        ASSERT_EQ(accumulatedResponses.size(), static_cast<std::size_t>(parameters.numRequests));

        // NOTE(review): only output shapes are checked; token values are not compared against
        // state->randomLogits. (A previous sort/reverse of those logits here was never read and has been removed.)
        for (auto const& [requestId, responses] : accumulatedResponses)
        {
            for (auto const& response : responses)
            {
                ASSERT_FALSE(response.hasError());
                auto const& tokensByBeam = response.getResult().outputTokenIds;
                ASSERT_EQ(tokensByBeam.size(), std::size_t{1});
                for (auto const& tokensForBeam : tokensByBeam)
                {
                    ASSERT_EQ(tokensForBeam.size(), static_cast<std::size_t>(parameters.maxOutputLength));
                }
            }
        }
    }
};

namespace
{
// Value for each test-parameter axis. Every axis currently has a single value, so exactly one combination runs.
constexpr runtime::SizeType32 kRandomSeed1 = 45;
constexpr runtime::SizeType32 kMinVocabSize = 16;
constexpr runtime::SizeType32 kMinMaxNumTokens = 2048;
constexpr runtime::SizeType32 kMinBeamWidth = 1;
constexpr runtime::SizeType32 kMinMaxBatchSize = 2048;
constexpr runtime::SizeType32 kMinNumRequests = 64;
constexpr runtime::SizeType32 kMinPromptLength = 32;
constexpr runtime::SizeType32 kMinMaxOutputLength = 16;

// One gtest generator per axis.
auto const randomSeeds = ::testing::Values(kRandomSeed1);
auto const vocabSizes = ::testing::Values(kMinVocabSize);
auto const maxNumTokenses = ::testing::Values(kMinMaxNumTokens);
auto const beamWidths = ::testing::Values(kMinBeamWidth);
auto const maxBatchSizes = ::testing::Values(kMinMaxBatchSize);
auto const numRequestses = ::testing::Values(kMinNumRequests);
auto const promptLengths = ::testing::Values(kMinPromptLength);
auto const maxOutputLengths = ::testing::Values(kMinMaxOutputLength);

// Cartesian product of all axes, in TupleT order, converted to TrivialConstantDecoderTestParameters.
auto const paramGenerator = ::testing::ConvertGenerator<TrivialConstantDecoderTestParameters::TupleT>(
    ::testing::Combine(randomSeeds, vocabSizes, maxNumTokenses, beamWidths, maxBatchSizes, numRequestses,
        promptLengths, maxOutputLengths));
} // namespace

// Fixture variant whose engine emits float logits.
using DecoderFloatTest = DecoderTest<float>;

// Runs the decoder scenario once for the current parameter combination.
TEST_P(DecoderFloatTest, TestSizeAndValues)
{
    auto const& currentParameters = GetParam();
    runDecoderTest(currentParameters);
}

// The name generator encodes every parameter into the test name, so each instantiation is identifiable.
INSTANTIATE_TEST_SUITE_P(Float, DecoderFloatTest, paramGenerator,
    [](::testing::TestParamInfo<TrivialConstantDecoderTestParameters> const& info) -> std::string
    {
        auto const& p = info.param;
        std::stringstream name;
        name << "_maxBatchSize_" << p.maxBatchSize;
        name << "_vocabSize_" << p.vocabSize;
        name << "_maxBeamWidth_" << p.maxBeamWidth;
        name << "_maxNumTokens_" << p.maxNumTokens;
        name << "_maxOutputLength_" << p.maxOutputLength;
        name << "_numRequests_" << p.numRequests;
        name << "_promptLength_" << p.promptLength;
        name << "_randomSeed_" << p.randomSeed;
        return name.str();
    });

// Builds a minimal float GPT model config (GPT attention plugin, paged KV cache) with the given dimensions and a
// default WorldConfig. It returns whatever TrtGptModelInflightBatching::calculateCacheSizePerTokenForDisagg
// reports for that config: a map keyed by attention window size.
std::map<runtime::SizeType32, runtime::SizeType32> calculateCacheSizePerTokenHelper(
    std::vector<runtime::SizeType32> const& maxAttentionWindowVec, runtime::SizeType32 kvFactor = 2,
    runtime::SizeType32 vocabSize = 32, runtime::SizeType32 nbLayers = 4, runtime::SizeType32 nbAttentionLayers = 4,
    runtime::SizeType32 nbRnnLayers = 0, runtime::SizeType32 nbHeads = 8, runtime::SizeType32 hiddenSize = 512,
    bool isCrossAttention = false)
{
    auto config = runtime::ModelConfig(
        vocabSize, nbLayers, nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, nvinfer1::DataType::kFLOAT);
    config.useGptAttentionPlugin(true);
    config.setModelVariant(runtime::ModelConfig::ModelVariant::kGpt);
    config.setKVCacheType(runtime::ModelConfig::KVCacheType::kPAGED);

    auto const world = runtime::WorldConfig();
    return batch_manager::TrtGptModelInflightBatching::calculateCacheSizePerTokenForDisagg(
        config, world, maxAttentionWindowVec, isCrossAttention, kvFactor);
}

// Tests TrtGptModelInflightBatching::calculateCacheSizePerTokenForDisagg with one or two attention window sizes,
// for models made of attention layers only and of mixed attention + RNN layers.
TEST(TrtInflightBatchingTest, CalculateCacheSizePerTokenForDisagg)
{
    // Model dimensions shared by every scenario.
    constexpr runtime::SizeType32 nbLayers = 5;
    constexpr runtime::SizeType32 hiddenSize = 512;
    constexpr runtime::SizeType32 kvFactor = 2;
    constexpr runtime::SizeType32 vocabSize = 32;
    constexpr runtime::SizeType32 nbHeads = 8;
    // The helper builds the model with nvinfer1::DataType::kFLOAT (4 bytes per element).
    constexpr runtime::SizeType32 numBytesPerFloatElement = 4;
    // Expected cache bytes contributed by one attention layer for one token.
    constexpr runtime::SizeType32 bytesPerAttentionLayer = kvFactor * hiddenSize * numBytesPerFloatElement;

    // Runs one scenario and checks the result against `expectedLayersPerWindow`.
    // That map gives, for each window size, the number of attention layers expected in that window.
    // A missing window key fails only this scenario instead of throwing out of the whole test.
    auto const checkScenario = [&](std::vector<runtime::SizeType32> const& maxAttentionWindowVec,
                                   runtime::SizeType32 nbAttentionLayers, runtime::SizeType32 nbRnnLayers,
                                   std::map<runtime::SizeType32, runtime::SizeType32> const& expectedLayersPerWindow)
    {
        auto const result = calculateCacheSizePerTokenHelper(maxAttentionWindowVec, kvFactor, vocabSize, nbLayers,
            nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, false);
        EXPECT_EQ(result.size(), expectedLayersPerWindow.size());
        for (auto const& [windowSize, nbLayersInWindow] : expectedLayersPerWindow)
        {
            auto const it = result.find(windowSize);
            ASSERT_TRUE(it != result.end()) << "missing attention window size " << windowSize;
            EXPECT_EQ(it->second, nbLayersInWindow * bytesPerAttentionLayer) << "attention window size " << windowSize;
        }
    };

    // Test case 1: Single attention window size - attention layers only.
    checkScenario({128}, /*nbAttentionLayers=*/5, /*nbRnnLayers=*/0, {{128, 5}});

    // Test case 2: Multiple attention window sizes - attention layers only.
    // 3 of the 5 attention layers are expected in the 128 window and 2 in the 256 window.
    checkScenario({128, 256}, /*nbAttentionLayers=*/5, /*nbRnnLayers=*/0, {{128, 3}, {256, 2}});

    // Test case 3: Single attention window size - attention and rnn layers.
    checkScenario({128}, /*nbAttentionLayers=*/3, /*nbRnnLayers=*/2, {{128, 3}});

    // Test case 4: Multiple attention window sizes - attention and rnn layers.
    // 2 of the 3 attention layers are expected in the 128 window and 1 in the 256 window.
    checkScenario({128, 256}, /*nbAttentionLayers=*/3, /*nbRnnLayers=*/2, {{128, 2}, {256, 1}});
}

} // namespace tensorrt_llm::testing
